#!/usr/bin/env python3

# Python script to run and analyse MMS test
#
# Tests a range of different schemes
#

from __future__ import division
from __future__ import print_function
from builtins import str

from boututils.run_wrapper import shell, shell_safe, launch_safe
from boutdata.collect import collect

#requires: all_tests

# Set to True to print the arguments and the error norms of every run
verbose = False

from numpy import sqrt, max, abs, mean, array, log, pi

from os.path import join

import sys

# matplotlib is optional: plotting is skipped when it cannot be imported.
# Catch Exception rather than using a bare except so that KeyboardInterrupt
# and SystemExit are not swallowed.
try:
    import matplotlib.pyplot as plt
except Exception:
    plt = None



print("Making MMS time integration test")
shell_safe("make > make.log")

# One entry per test:
#   (command-line options, label, plot line style, expected convergence order)
if "only_petsc" not in sys.argv:
    options = [
        ("solver:type=euler", "Euler", "-^", 1),
        ("solver:type=rk3ssp", "RK3-SSP", "-s", 3),
        ("solver:type=rk4", "RK4", "-o", 4),
        ("solver:type=rkgeneric", "RK-generic", "-o", 4),
        ("solver:type=splitrk", "SplitRK", "-x", 2),
    ]
else:
    # This requires PETSc
    options = [
        ("solver:type=imexbdf2 -snes_mf", "IMEX-BDF2", "-+", 2),
    ]

# Missing: cvode, ida, slepc, power, arkode, snes
# Is there a better way to check whether a certain solver should be enabled?

# Each value nt gives one run with solver:timestep = 1/nt
timesteps = [4, 8, 16, 32, 64, 128, 256]

# Number of processes used to launch each run
nproc = 1

# Cleared as soon as one solver misses its expected convergence order
success = True

# Results of every solver as (errors, label, line style), kept for the
# comparison plot
err_2_all = []
err_inf_all = []

# Time step sizes corresponding to `timesteps`; identical for every solver,
# so computed once rather than on each pass through the loop
dt = 1. / array(timesteps)

# Set to True to show an error plot for each solver (needs matplotlib)
plot_each_solver = False

for opts, label, sym, expected_order in options:
    error_2 = []  # The L2 error (RMS)
    error_inf = []  # The maximum error

    for nt in timesteps:
        args = " solver:timestep=" + str(1. / nt) + " " + opts

        if verbose:
            print("  Running with " + args)

        # Delete old data; -f so that a clean data directory is not an error
        shell("rm -f data/BOUT.dmp.*.nc")

        # Command to run
        cmd = "./time " + args
        # Launch using MPI
        s, out = launch_safe(cmd, nproc=nproc, pipe=True)

        # Save output to log file; the context manager closes the file even
        # if the write fails
        with open("run.log." + label + "." + str(nt), "w") as f:
            f.write(out)

        # Collect data
        E_f = collect("E_f", tind=[1, 1], info=False, path="data")
        # Average error over domain
        l2 = sqrt(mean(E_f**2))
        linf = max(abs(E_f))

        error_2.append(l2)
        error_inf.append(linf)

        if verbose:
            # %e rather than %f, so that small errors are not shown as 0.000000
            print("  -> Error norm: l-2 %e l-inf %e" % (l2, linf))

    # Append to list of all results
    err_2_all.append((error_2, label, sym))
    err_inf_all.append((error_inf, label, sym))

    # Calculate convergence order from the two smallest time steps
    order = log(error_2[-1] / error_2[-2]) / log(dt[-1] / dt[-2])
    print("Convergence order = %f (expected: %f) %s" % (order, expected_order, label))

    # Written as a negated range test so that a NaN order counts as a failure
    if not (expected_order - 0.2 < order < expected_order + 0.2):
        success = False
        print(" -> FAILED")

    # Plot errors
    if plot_each_solver and plt is not None:
        plt.figure()

        plt.plot(dt, error_2, '-o', label=r'$l^2$')
        plt.plot(dt, error_inf, '-x', label=r'$l^\infty$')

        plt.plot(dt, error_2[-1]*(dt/dt[-1])**order, '--', label="Order %.1f"%(order))

        plt.legend(loc="lower right")
        plt.grid()

        plt.yscale('log')
        plt.xscale('log')

        plt.xlabel(r'Time step $\delta t$')
        plt.ylabel("Error norm")

        #plt.savefig("norm.pdf")

        plt.show()
        plt.close()

# Plot the L2 error of all solvers for comparison.
# Skipped when matplotlib could not be imported (plt is None).
if plt is not None:
    try:
        # Time step sizes corresponding to `timesteps`
        plot_dt = 1. / array(timesteps)

        plt.figure()
        for e, label, sym in err_2_all:
            plt.plot(plot_dt, e, sym, label=label)

        plt.legend(loc="lower right")
        plt.grid()
        plt.yscale('log')
        plt.xscale('log')

        plt.xlabel(r'Time step $\delta t$')
        plt.ylabel(r'$l^2$ error norm')
        plt.savefig("time_norm.pdf")
        plt.close()
    except Exception as err:
        # Plotting is best-effort and must not change the test result,
        # but report the problem instead of hiding it
        print("Could not create time_norm.pdf: %s" % err)

# Report the overall result through the process exit status.
# sys.exit is used instead of the exit() helper, which is added by the site
# module for interactive use and is not guaranteed to exist.
if success:
    print(" => All tests passed")
    sys.exit(0)
else:
    print(" => Some failed tests")
    sys.exit(1)
